{
 "cells": [
  {
   "cell_type": "code",
   "execution_count": 1,
   "metadata": {},
   "outputs": [],
   "source": [
    "import torch\n",
    "from src.data.components.helsinki_dataset import Dataset, load_dataset, pad"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 2,
   "metadata": {},
   "outputs": [],
   "source": [
    "import os\n",
    "\n",
    "from omegaconf import DictConfig, OmegaConf\n",
    "\n",
    "# Directory holding the helsinki-prosody data files. Set the PROSODY_DATADIR\n",
    "# environment variable to point at a local copy; the fallback is the path this\n",
    "# notebook was originally written against.\n",
    "DEFAULT_DATADIR = (\n",
    "    \"/Users/lukas/Desktop/projects/MIT/prosody/prosody/repositories/helsinki-prosody/data/\"\n",
    ")\n",
    "\n",
    "# Settings object passed to load_dataset, Dataset and the model in later cells.\n",
    "config = OmegaConf.create(\n",
    "    {\n",
    "        # NOTE(review): the training set is read from \"test.txt\" - confirm this\n",
    "        # is intentional (e.g. a small file for quick experiments).\n",
    "        \"train_set\": \"test.txt\",\n",
    "        \"datadir\": os.environ.get(\"PROSODY_DATADIR\", DEFAULT_DATADIR),\n",
    "        \"fraction_of_train_data\": 1,\n",
    "        \"nclasses\": 2,\n",
    "        \"shuffle_sentences\": True,\n",
    "        \"sorted_batches\": True,\n",
    "        \"model\": \"gpt2\",\n",
    "        \"log_values\": False,\n",
    "        \"invalid_set_to\": False,\n",
    "        \"mask_invalid_grads\": True,\n",
    "    }\n",
    ")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Read the data described by `config` and unpack the four objects returned.\n",
    "(\n",
    "    splits,\n",
    "    tag_to_index,\n",
    "    index_to_tag,\n",
    "    vocab,\n",
    ") = load_dataset(config)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# No word -> embedding-id mapping is supplied to the datasets.\n",
    "word_to_embid = None\n",
    "\n",
    "# Build one Dataset per split, in train / dev / test order; the \"dev\" split is\n",
    "# exposed under the name `eval_dataset`.\n",
    "train_dataset, eval_dataset, test_dataset = (\n",
    "    Dataset(splits[split_name], tag_to_index, config, word_to_embid)\n",
    "    for split_name in (\"train\", \"dev\", \"test\")\n",
    ")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Batch the training set, using `pad` to collate each batch. Batches keep the\n",
    "# dataset's order (shuffle=False) and are loaded in the main process\n",
    "# (num_workers=0).\n",
    "dataloader = torch.utils.data.DataLoader(\n",
    "    dataset=train_dataset,\n",
    "    batch_size=32,\n",
    "    shuffle=False,\n",
    "    num_workers=0,\n",
    "    collate_fn=pad,\n",
    ")\n",
    "\n",
    "# Number of batches in one pass over the training set.\n",
    "len(dataloader)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "from transformers import BertModel, BertTokenizer, GPT2Tokenizer\n",
    "\n",
    "from src.data.components.helsinki_dataset import weighted_mse_loss\n",
    "from src.models.helsinki_models import BertRegression\n",
    "\n",
    "# Choose the accelerator once, before anything is moved onto it. Apple's MPS\n",
    "# backend is preferred (the device this notebook was written for); fall back to\n",
    "# CUDA or CPU so the cell also runs on machines without MPS.\n",
    "if torch.backends.mps.is_available():\n",
    "    device = \"mps\"\n",
    "elif torch.cuda.is_available():\n",
    "    device = \"cuda\"\n",
    "else:\n",
    "    device = \"cpu\"\n",
    "\n",
    "tokenizer = BertTokenizer.from_pretrained(\"bert-base-uncased\")\n",
    "gpt2_tokenizer = GPT2Tokenizer.from_pretrained(\"gpt2\")\n",
    "\n",
    "# The model receives the device name and the config, then is moved to the device.\n",
    "model = BertRegression(device, config).to(device)\n",
    "\n",
    "# Mean-squared-error loss; the weighted variant is kept as a drop-in alternative.\n",
    "# criterion = weighted_mse_loss\n",
    "criterion = torch.nn.MSELoss().to(device)\n",
    "optimizer = torch.optim.Adam(model.parameters(), lr=0.00005)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "from collections import deque\n",
    "from tqdm import tqdm\n",
    "\n",
    "# Rolling window holding the 50 most recent batch losses.\n",
    "queue = deque(maxlen=50)\n",
    "\n",
    "total_iterations = len(dataloader)\n",
    "\n",
    "# A single progress bar drives the loop; its description shows the running\n",
    "# average loss. (Previously a second, never-updated bar was created alongside\n",
    "# the one wrapping the loop.)\n",
    "pbar = tqdm(\n",
    "    enumerate(dataloader),\n",
    "    total=total_iterations,\n",
    "    desc=\"Loss: N/A\",\n",
    "    bar_format=\"{desc} |{bar}| {percentage:3.0f}% {r_bar}\",\n",
    ")\n",
    "\n",
    "for i, batch in pbar:\n",
    "    words, x, is_main_piece, tags, y, seqlens, values, _ = batch\n",
    "\n",
    "    # REGRESSION (alternative training step, kept for reference)\n",
    "    # optimizer.zero_grad()\n",
    "    # x = x.to(device)\n",
    "    # values = values.to(device)\n",
    "    # predictions, true = model(x, values)\n",
    "    # loss = criterion(predictions.to(device), true.float().to(device))\n",
    "    # loss.backward()\n",
    "    # optimizer.step()\n",
    "\n",
    "    optimizer.zero_grad()\n",
    "    x = x.to(device)\n",
    "    y = y.to(device)\n",
    "    logits, y, _ = model(x, y)  # logits: (N, T, VOCAB), y: (N, T)\n",
    "    logits = logits.view(-1, logits.shape[-1])  # (N*T, VOCAB)\n",
    "    y = y.view(-1)  # (N*T,)\n",
    "    # NOTE(review): `criterion` is created as torch.nn.MSELoss in the setup\n",
    "    # cell, while logits/targets are flattened here to (N*T, VOCAB) / (N*T,) as\n",
    "    # for a classification loss - confirm which criterion is intended.\n",
    "    loss = criterion(logits.to(device), y.to(device))\n",
    "    loss.backward()\n",
    "    optimizer.step()\n",
    "\n",
    "    # Track the running average over the window and surface it on the bar.\n",
    "    queue.append(loss.item())\n",
    "    avg_loss = sum(queue) / len(queue)\n",
    "    pbar.set_description_str(f\"Loss: {avg_loss:.4f}\")\n",
    "\n",
    "    # Also log the windowed average every 50 batches.\n",
    "    if (i + 1) % 50 == 0:\n",
    "        print(f\"Avg loss: {avg_loss}\")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": []
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": []
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "prosody",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.10.9"
  },
  "orig_nbformat": 4
 },
 "nbformat": 4,
 "nbformat_minor": 2
}
